from .lmm.mistral import DattnMistralForCausalLM, DattnMistralConfig
from .lmm.gemma import DattnGemma2ForCausalLM, DattnGemma2Config


def get_dattn_cls(model_name_or_path) -> type:
    """Return the Dattn causal-LM class that matches a model name or path.

    Matching is a case-insensitive substring test on the whole name/path
    string. ``"mistral"`` is checked before ``"gemma"``, so an identifier
    containing both resolves to the Mistral class.

    Args:
        model_name_or_path: Model identifier or local checkpoint path. Either
            a ``str`` or a path-like object such as ``pathlib.Path``.

    Returns:
        ``DattnMistralForCausalLM`` or ``DattnGemma2ForCausalLM``.

    Raises:
        NotImplementedError: If no supported model family appears in
            ``model_name_or_path``.
    """
    # Coerce path-like inputs (e.g. pathlib.Path) to text so ``.lower()``
    # works. Lowercase once rather than once per branch.
    lowered = str(model_name_or_path).lower()
    if "mistral" in lowered:
        return DattnMistralForCausalLM
    if "gemma" in lowered:
        return DattnGemma2ForCausalLM
    raise NotImplementedError(f"Unsupported model type: {model_name_or_path}")